package main

import (
	"fmt"
	"os"
	"os/exec"
	"path"
	"text/template"

	"github.com/gomlx/gomlx/internal/backendparser"
	"github.com/gomlx/gomlx/internal/must"
	"github.com/gomlx/gomlx/pkg/support/sets"
	"k8s.io/klog/v2"
)

// unaryOpsFile is the base name of the generated unary-ops file. It is also used as the template's name
// and is written into the current working directory by GenerateUnaryOps.
const unaryOpsFile = "gen_unary_ops.go"

// IsUnaryOp reports whether method should get an auto-generated unary-op implementation.
//
// That is the case when the method takes exactly one parameter of type "Op", returns exactly
// (Op, error), and its name is not listed in UnaryOpsToExclude.
func IsUnaryOp(method backendparser.Method) bool {
	params, outs := method.Parameters, method.Outputs
	takesSingleOp := len(params) == 1 && params[0].Type == "Op"
	if !takesSingleOp || UnaryOpsToExclude.Has(method.Name) {
		return false
	}
	// Only the (Op, error) result shape is supported by the generated code.
	return len(outs) == 2 && outs[0].Type == "Op" && outs[1].Type == "error"
}

var (
	// UnaryOpAliases maps GoMLX backend unary op names to the names of their StableHLO equivalents.
	// Only ops whose names differ are listed: ops missing from this map are generated calling the
	// stablehlo function with the same name as the backend op.
	UnaryOpAliases = map[string]string{
		"BitCount":   "Popcnt",
		"BitwiseNot": "Not",
		"Clz":        "CountLeadingZeros",
		"Cos":        "Cosine",
		"Exp":        "Exponential",
		"Expm1":      "ExponentialMinusOne",
		"Log1p":      "LogPlusOne",
		"LogicalNot": "Not",
		"Neg":        "Negate",
		"Round":      "RoundNearestEven",
		"Sin":        "Sine",
	}

	// UnaryOpsToExclude lists backend op names that would otherwise qualify as unary ops (see IsUnaryOp)
	// but must not be auto-generated -- presumably they are implemented by hand elsewhere.
	UnaryOpsToExclude = sets.MakeWith(
		"Identity", "Conj", "IsNaN",
		"Abs", // Abs is excluded because we special-handle values of complex types.
	)

	// unaryOpsTemplate renders unaryOpsFile. It is executed with a slice of items exposing .Method (the
	// backend method, whose comments are copied over) and .Alias (the stablehlo function to call).
	//
	// The first line follows Go's generated-code convention ("// Code generated ... DO NOT EDIT."), so
	// gopls, linters and code review tools recognize the output as generated.
	unaryOpsTemplate = template.Must(template.New(unaryOpsFile).Parse(`
// Code generated by ./internal/cmd/stable_generator. DO NOT EDIT.

/***** File generated by ./internal/cmd/stable_generator, based on the backends' ops. Don't edit it directly. *****/

package stablehlo

import (
	"github.com/gomlx/gomlx/backends"
	"github.com/gomlx/stablehlo"
)

{{- range .}}
{{- range .Method.Comments}}
{{.}}
{{- end}}
func (b *Builder) {{.Method.Name}}(operand backends.Op) (backends.Op, error) {
	nodes, err := b.verifyAndCastValues("{{.Method.Name}}", operand)
	if err != nil {
		return nil, err
	}
	value, err := stablehlo.{{.Alias}}(nodes[0].value)
	if err != nil {
		return nil, err
	}
	return b.newNode(value), nil
}
{{end}}
`))
)

// GenerateUnaryOps writes unaryOpsFile into the current working directory and formats it with "go fmt".
//
// The file contains one Builder method for each entry of methods that satisfies IsUnaryOp. Each
// generated method calls the stablehlo function with the op's own name, or the name given in
// UnaryOpAliases when the op is listed there. Errors at every step (getting the working directory,
// creating, rendering or closing the file, running "go fmt") are passed to the must helpers.
func GenerateUnaryOps(methods []backendparser.Method) {
	// unaryOpData is the per-op item consumed by unaryOpsTemplate (fields .Method and .Alias).
	type unaryOpData struct {
		Method backendparser.Method
		Alias  string
	}
	var data []unaryOpData
	for _, method := range methods {
		if !IsUnaryOp(method) {
			continue
		}
		alias := method.Name
		if a, ok := UnaryOpAliases[method.Name]; ok {
			alias = a
		}
		data = append(data, unaryOpData{Method: method, Alias: alias})
	}

	fileName := path.Join(must.M1(os.Getwd()), unaryOpsFile)
	f := must.M1(os.Create(fileName))
	must.M(unaryOpsTemplate.Execute(f, data))
	// Close the file, and check the error, before "go fmt" rewrites it. Otherwise the descriptor leaks
	// and any error reported at close time is silently lost.
	must.M(f.Close())

	cmd := exec.Command("go", "fmt", fileName)
	klog.V(1).Infof("\t%s\n", cmd)
	cmd.Stderr = os.Stderr
	must.M(cmd.Run())
	fmt.Printf("✅ stablehlo_generator:  \tsuccessfully generated %s\n", fileName)
}
